package dns

import (
	"context"
	"encoding/hex"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"regexp"
	"slices"
	"strings"
)

var (
	DNSLookupAddr = net.LookupAddr
	DNSLookupHost = net.LookupHost
)

type Dns struct {
	cache *DnsCache
	ctx   context.Context
}

func New(ctx context.Context, cache *DnsCache) *Dns {
	return &Dns{
		cache: cache,
		ctx:   ctx,
	}
}

// ReverseDNS performs a reverse DNS lookup for the given IP address and trims the trailing dot from the results.
func (d *Dns) ReverseDNS(addr string) ([]string, error) {
	slog.Debug("DNS: performing reverse lookup", "addr", addr)

	if cached, ok := d.getCachedReverse(addr); ok {
		return cached, nil
	}

	names, err := DNSLookupAddr(addr)
	if err != nil {
		if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound {
			slog.Debug("DNS: no PTR record found", "addr", addr)
			return []string{}, nil
		}
		slog.Error("DNS: reverse lookup failed", "addr", addr, "err", err)
		return nil, err
	}

	slog.Debug("DNS: reverse lookup successful", "addr", addr, "names", names)

	trimmedNames := make([]string, len(names))
	for i, name := range names {
		trimmedNames[i] = strings.TrimSuffix(name, ".")
	}
	d.reverseCachePut(addr, trimmedNames)

	return trimmedNames, nil
}

// LookupHost performs a forward DNS lookup for the given hostname.
func (d *Dns) LookupHost(host string) ([]string, error) {
	slog.Debug("DNS: performing forward lookup", "host", host)

	if cached, ok := d.getCachedForward(host); ok {
		return cached, nil
	}

	addrs, err := DNSLookupHost(host)
	if err != nil {
		if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound {
			slog.Debug("DNS: no A/AAAA record found", "host", host)
			return []string{}, nil
		}
		slog.Error("DNS: forward lookup failed", "host", host, "err", err)
		return nil, err
	}

	slog.Debug("DNS: forward lookup successful", "host", host, "addrs", addrs)
	d.forwardCachePut(host, addrs)
	return addrs, nil
}

// verifyFCrDNSInternal performs the second half of the FCrDNS check, using a
// pre-fetched list of names to perform the forward lookups.
func (d *Dns) verifyFCrDNSInternal(addr string, names []string) bool {
	for _, name := range names {
		if cached, err := d.LookupHost(name); err == nil {
			if slices.Contains(cached, addr) {
				slog.Info("DNS: forward lookup confirmed original IP", "name", name, "addr", addr)
				return true
			}
			continue
		}
	}

	slog.Info("DNS: could not confirm original IP in forward lookups", "addr", addr)
	return false
}

// VerifyFCrDNS performs a forward-confirmed reverse DNS (FCrDNS) lookup for the given IP address,
// optionally matching against a provided pattern.
func (d *Dns) VerifyFCrDNS(addr string, pattern *string) bool {
	var patternVal string
	if pattern != nil {
		patternVal = *pattern
	}
	slog.Debug("DNS: performing FCrDNS lookup", "addr", addr, "pattern", patternVal)

	names, err := d.ReverseDNS(addr)
	if err != nil {
		return false
	}
	if len(names) == 0 {
		return pattern == nil // If no pattern specified, check is passed
	}

	// If a pattern is provided, check for a match.
	if pattern != nil {
		anyNameMatched := false
		for _, name := range names {
			matched, err := regexp.MatchString(*pattern, name)
			if err != nil {
				slog.Error("DNS: verifyFCrDNS invalid regex pattern", "err", err)
				return false // Invalid pattern is a failure.
			}
			if matched {
				anyNameMatched = true
				break
			}
		}

		if !anyNameMatched {
			slog.Debug("DNS: FCrDNS no PTR matches the pattern", "addr", addr, "pattern", *pattern)
			return false
		}
		slog.Debug("DNS: FCrDNS PTR matched pattern, proceeding with forward check", "addr", addr, "pattern", *pattern)
	}

	// If we're here, either there was no pattern, or the pattern matched.
	// Proceed with the forward lookup confirmation.
	return d.verifyFCrDNSInternal(addr, names)
}

// ArpaReverseIP performs translation from ip v4/v6 to arpa reverse notation
func (d *Dns) ArpaReverseIP(addr string) (string, error) {
	ip := net.ParseIP(addr)
	if ip == nil {
		return addr, errors.New("invalid IP address")
	}

	if ipv4 := ip.To4(); ipv4 != nil {
		return fmt.Sprintf("%d.%d.%d.%d", ipv4[3], ipv4[2], ipv4[1], ipv4[0]), nil
	}

	ipv6 := ip.To16()
	if ipv6 == nil {
		return addr, errors.New("invalid IPv6 address")
	}

	hexBytes := make([]byte, hex.EncodedLen(len(ipv6)))
	hex.Encode(hexBytes, ipv6)

	var sb strings.Builder
	sb.Grow(len(hexBytes)*2 - 1)

	for i := len(hexBytes) - 1; i >= 0; i-- {
		sb.WriteByte(hexBytes[i])
		if i > 0 {
			sb.WriteByte('.')
		}
	}
	return sb.String(), nil
}
